
Conversation

@nlpcat (Contributor) commented Jul 30, 2022

What does this PR do?

Supports dynamic batch input for tf.function + generate (XLA); this is needed for batched TF Serving.

Export:

import tensorflow as tf
from transformers import TFAutoModelForSeq2SeqLM

class MyOwnModel(tf.Module):
    def __init__(self, model_path="t5-small"):
        super().__init__()
        self.model = TFAutoModelForSeq2SeqLM.from_pretrained(model_path)

    @tf.function(
        input_signature=(
            tf.TensorSpec((None, 32), tf.int32, name="input_ids"),
            tf.TensorSpec((None, 32), tf.int32, name="attention_mask"),
        ),
        jit_compile=True,
    )
    def serving(self, input_ids, attention_mask):
        outputs = self.model.generate(
            input_ids=input_ids, attention_mask=attention_mask, max_new_tokens=32, return_dict_in_generate=True
        )
        return {"sequences": outputs["sequences"]}

model = MyOwnModel()
export_dir = "./"
tf.saved_model.save(model, export_dir, signatures={"serving_default": model.serving})
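
To confirm that the exported signature really has a dynamic batch dimension, the SavedModel can be reloaded and inspected (a quick sanity check, not part of the PR):

import tensorflow as tf

loaded = tf.saved_model.load("./")
serving_fn = loaded.signatures["serving_default"]
# Expect input_ids / attention_mask specs with shape (None, 32): the batch axis is dynamic
print(serving_fn.structured_input_signature)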

Run the exported TF model:

import tensorflow as tf
from transformers import AutoTokenizer
export_dir = "./"
model = tf.saved_model.load(export_dir)

tokenizer = AutoTokenizer.from_pretrained("t5-small")
tokenization_kwargs = {"pad_to_multiple_of": 32, "padding": True, "return_tensors": "tf"}

input_prompts = [
    f"translate English to {language}: I have four cats and three dogs."
    for language in ["German", "French", "Romanian"]
]

def generate_text(inputs):
    tokenized_inputs = tokenizer(inputs, **tokenization_kwargs)
    generated_texts = model.signatures["serving_default"](**tokenized_inputs)
    for text in generated_texts["sequences"]:
        print(tokenizer.decode(text, skip_special_tokens=True))

# The first call is slow (XLA compilation); later calls, even with a different batch size, are fast!
generate_text(input_prompts[:2])
generate_text(input_prompts[:3])

XLA run (without export):

import tensorflow as tf
from transformers import AutoTokenizer, TFAutoModelForSeq2SeqLM

tokenizer = AutoTokenizer.from_pretrained("t5-small")
model = TFAutoModelForSeq2SeqLM.from_pretrained("t5-small")

# Main changes with respect to the original generate workflow: `tf.function` and `pad_to_multiple_of`
xla_generate = tf.function(model.generate, jit_compile=True)
tokenization_kwargs = {"pad_to_multiple_of": 32, "padding": True, "return_tensors": "tf"}

# The first prompt will be slow (compiling), the others will be very fast!
input_prompts = [
    f"translate English to {language}: I have four cats and three dogs."
    for language in ["German", "French", "Romanian"]
]
tokenized_inputs = tokenizer(input_prompts, **tokenization_kwargs)
generated_texts = xla_generate(**tokenized_inputs, max_new_tokens=32)
for text in generated_texts:
    print(tokenizer.decode(text, skip_special_tokens=True))
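
For context on why pad_to_multiple_of is the other key change: XLA compiles one program per static input shape, so padding every batch to a fixed multiple keeps the padded sequence length stable across prompts and avoids recompilation. A minimal, self-contained illustration of the retracing behavior (not from the PR):

import tensorflow as tf

@tf.function(jit_compile=True)
def row_sums(x):
    return tf.reduce_sum(x, axis=-1)

row_sums(tf.zeros((1, 7)))   # traces and compiles for length 7
row_sums(tf.zeros((1, 9)))   # new static length -> traces and compiles again
row_sums(tf.zeros((1, 32)))  # padding everything to a multiple of 32...
row_sums(tf.zeros((1, 32)))  # ...reuses one cached program, no recompilation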

This also works for beam search; change the exported serving method as follows:

def serving(self, input_ids, attention_mask):
    outputs = self.model.generate(
        input_ids=input_ids, attention_mask=attention_mask, max_new_tokens=32,
        return_dict_in_generate=True, num_beams=3, num_return_sequences=3,
    )
    return {"sequences": outputs["sequences"]}
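
With num_beams=3 and num_return_sequences=3, generate stacks the returned candidates along the batch axis, so sequences has shape (batch_size * 3, sequence_length). A small client-side sketch for grouping them back per prompt, reusing the names from the serving example above (illustrative only):

outputs = model.signatures["serving_default"](**tokenized_inputs)
for i, text in enumerate(outputs["sequences"]):
    print(f"prompt {i // 3}, candidate {i % 3}:", tokenizer.decode(text, skip_special_tokens=True))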

Fixes #18357
Fixes #16823

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a Github issue or the forum? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines, and
    here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

cc @gante @patrickvonplaten

@nlpcat nlpcat marked this pull request as ready for review July 30, 2022 01:58
@HuggingFaceDocBuilderDev commented Jul 30, 2022

The documentation is not available anymore as the PR was closed or merged.

@nlpcat nlpcat force-pushed the fix.generate.batch branch from d4b873f to 59e80ca Compare July 30, 2022 06:18
@nlpcat nlpcat changed the title from "change shape to support dynamic batch input in tf.function XLA generate" to "change shape to support dynamic batch input in tf.function XLA generate for tf serving" Aug 1, 2022
@nlpcat (Contributor, Author) commented Aug 3, 2022

Cc @gante @patrickvonplaten

@gante (Contributor) commented Aug 3, 2022

Hi @nlpcat 👋 I see the change is needed because an unknown batch size is specified (hence the need for dynamic shapes). I'm going to double-check a few cases against this branch and, if all goes well, I may propose a few changes.

In general, I'm in favor of adding the change, thank you for the PR :)

@gante (Contributor) left a review comment

Thank you for this contribution! (I've double-checked that it doesn't affect the performance of generate)

One bit is missing, if you're up to it -- a test to ensure we don't lose this feature. The best place would probably be UtilsFunctionsTest inside test_modeling_tf_common.py, and the test could be a copy of the example you shared in the PR description.

Let us know if you'd rather have us add the test instead :)
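
For reference, a rough sketch of what such a test could look like; the name and structure below are illustrative, not the test that was eventually merged:

import tensorflow as tf
from transformers import AutoTokenizer, TFAutoModelForSeq2SeqLM

def test_xla_generate_dynamic_batch():
    tokenizer = AutoTokenizer.from_pretrained("t5-small")
    model = TFAutoModelForSeq2SeqLM.from_pretrained("t5-small")

    @tf.function(
        input_signature=(
            tf.TensorSpec((None, 32), tf.int32, name="input_ids"),
            tf.TensorSpec((None, 32), tf.int32, name="attention_mask"),
        ),
        jit_compile=True,
    )
    def serving(input_ids, attention_mask):
        return model.generate(input_ids=input_ids, attention_mask=attention_mask, max_new_tokens=32)

    prompt = "translate English to German: I have four cats and three dogs."
    # both batch sizes must run through the same traced function without retracing
    for batch in ([prompt], [prompt, prompt]):
        inputs = tokenizer(batch, padding="max_length", max_length=32, return_tensors="tf")
        sequences = serving(inputs["input_ids"], inputs["attention_mask"])
        assert sequences.shape[0] == len(batch)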

@gante gante requested a review from sgugger August 3, 2022 16:46
@sgugger (Collaborator) left a review comment

Yes, adding a test would be nice.
Thanks a lot for your PR!

@gante (Contributor) commented Aug 3, 2022

(edited the PR header to link more issues this PR fixes :) )

@nlpcat nlpcat force-pushed the fix.generate.batch branch from 59e80ca to 596ecf4 Compare August 4, 2022 07:42
@nlpcat (Contributor, Author) commented Aug 4, 2022

@gante @sgugger I have added the test in 596ecf4.
Can you review and merge this PR if it looks good? Thanks.

@gante (Contributor) commented Aug 4, 2022

@nlpcat this is fantastic! Thank you so much for your contribution 🙏

@gante gante merged commit fc1d841 into huggingface:main Aug 4, 2022
@s4sarath commented
The whole idea of TensorFlow in Hugging Face is very complicated and a pain.
@nlpcat - you'd better look into

https://github.com/legacyai/tf-transformers/blob/main/docs/source/model_usage/text_generation_using_t5.ipynb

@rafaellemay commented
I was testing this code, but I found an issue with my model: I think the file tf_logits_process.py also needs to use the shape_list function to support dynamic batch input.
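
For readers unfamiliar with it, shape_list (in transformers.tf_utils; older versions expose it from transformers.modeling_tf_utils) returns static dimensions where known and falls back to tf.shape for the rest, which is exactly what dynamic batch support needs. A minimal sketch of the pattern being suggested, not the actual tf_logits_process.py change:

import tensorflow as tf
from transformers.tf_utils import shape_list  # assumed import path, see note above

@tf.function(input_signature=(tf.TensorSpec((None, 32), tf.int32),), jit_compile=True)
def batch_sized_zeros(input_ids):
    # input_ids.shape[0] is None inside this trace; shape_list returns a usable
    # dynamic value for the batch axis and the static 32 for the sequence axis
    batch_size, seq_len = shape_list(input_ids)
    return tf.zeros((batch_size, seq_len), dtype=tf.float32)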

@gante (Contributor) commented Dec 15, 2022

@rafaellemay can you open an issue with the problem that you found (and a snippet containing an example)? It would help us ensure the library works well in all cases :)


Development

Successfully merging this pull request may close these issues.

  • generate with tf.function (xla) not working for tf model export
  • (TF) model.generate to tf.function for tf serving
